/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef EXECUTE_GRAPH_INFER_SHAPE_MAT_MUL_H
#define EXECUTE_GRAPH_INFER_SHAPE_MAT_MUL_H
#include <vector>
#include <string>
#include "shape.h"
#include "definitions.h"
#include "operator.h"

using std::vector;
using std::string;

namespace ops {
/**
 * Shape / shape-range inference helper for MatMul-style operators.
 *
 * Judging by its members and method names, it reads the shapes and shape
 * ranges of inputs A, B and bias, together with the transpose flags.
 * It then writes an inferred output shape and range into caller-owned
 * objects. NOTE(review): confirm the exact algorithm against the .cpp.
 *
 * Lifetime: every shape/range member is a reference and op_name is a raw
 * pointer. The objects passed to the constructor must therefore outlive
 * this instance (assuming the constructor binds them directly; confirm
 * in the implementation).
 */
class InferShapeMatMul {
 public:
  // Performs the inference and presumably fills shape_out / range_out.
  // The bool result presumably signals success (false on inconsistent
  // inputs). TODO confirm in the implementation.
  bool GetShapeRangeOfOutput();
  // Parameters (semantics inferred from names; verify against the .cpp):
  //   op_name                    - operator name; presumably kept as a raw
  //                                pointer (see member below), not copied.
  //   shape_a/shape_b/shape_bias - input shapes, taken read-only.
  //   range_a/range_b/range_bias - input shape ranges, taken by non-const
  //                                reference, so the implementation may
  //                                modify them.
  //   trans_a/trans_b            - presumably whether A / B are transposed.
  //   shape_out/range_out        - destination for the inferred output.
  //   has_batch                  - presumably whether batch dimensions are
  //                                present (BatchMatMul-style inputs).
  InferShapeMatMul(const char *op_name,
                   const ge::Shape &shape_a, const ge::Shape &shape_b, const ge::Shape &shape_bias,
                   ge::Range &range_a, ge::Range &range_b, ge::Range &range_bias,
                   bool trans_a, bool trans_b,
                   ge::Shape &shape_out, ge::Range &range_out,
                   bool has_batch);

 private:
  // Internal inference steps. Descriptions below are inferred from the
  // names only; confirm against the implementation.
  // Presumably copies/normalizes inputs into the infer_* working vectors.
  void NormalizeShapeAndRange();
  // Presumably infers the M, K and N matrix dimensions; bool = success.
  bool InferMKN();
  // Presumably infers (broadcast) batch dimensions; bool = success.
  bool InferBatch();
  // Presumably reduces the working shapes/ranges to a simpler form.
  void SimplifyShapeAndRange();
  // Presumably reports whether all relevant dimensions are known.
  bool IsStaticShape();

  // Defined out of line; its value and meaning are set in the .cpp
  // (possibly the rank of a plain 2-D matrix; TODO confirm).
  static const int64_t base_len;
  // Do not reorder the members below. Non-static data members are
  // initialized in declaration order, and the reference members must be
  // bound in the constructor's initializer list.
  const char *op_name;
  const ge::Shape &shape_a;
  const ge::Shape &shape_b;
  const ge::Shape &shape_bias;
  ge::Range &range_a;
  ge::Range &range_b;
  ge::Range &range_bias;
  bool trans_a;
  bool trans_b;
  ge::Shape &shape_out;
  ge::Range &range_out;
  bool has_batch;
  // Presumably the working rank used by the inference steps (TODO confirm).
  int64_t num_dim;

  // Working copies of the input dims and per-dim (lower, upper) ranges
  // used during inference. Presumably populated by NormalizeShapeAndRange.
  vector<int64_t> infer_shape_a;
  vector<int64_t> infer_shape_b;
  vector<int64_t> infer_shape_bias;
  vector<std::pair<int64_t, int64_t>> infer_range_a;
  vector<std::pair<int64_t, int64_t>> infer_range_b;
  vector<std::pair<int64_t, int64_t>> infer_range_bias;
};

// Operator-level shape inference entry point for MatMul. It takes the
// operator and returns a ge::Status. Presumably it builds an
// InferShapeMatMul from the operator's inputs and attributes; confirm in
// the implementation.
ge::Status MatMulInferShape(const ge::Operator &op);
}  // namespace ops

#endif  //EXECUTE_GRAPH_INFER_SHAPE_MAT_MUL_H
